{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Reducing warm-up time for compilation"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "<a target=\"_blank\" href=\"https://colab.research.google.com/github/PrunaAI/pruna/blob/v|version|/docs/tutorials/portable_compilation.ipynb\">\n",
    "    <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
    "</a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this tutorial, we walk you through how to use the `pruna` package to [compile your model in a way that significantly reduces warm-up time](https://pytorch.org/tutorials/recipes/torch_compile_caching_tutorial.html) when re-loading the model on a new machine. Please be aware that, as of now, this only applies when the new machine has hardware identical to that of the machine the model was compiled on. The inference and compilation times reported here were measured on an NVIDIA H100."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# if you are not running the latest version of this tutorial, make sure to install the matching version of pruna\n",
    "# the following command will install the latest version of pruna\n",
    "%pip install pruna"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 0. Setup\n",
    "\n",
    "As a first step, we do a brief setup to mimic loading a compiled model on a new machine. To do so, we explicitly set the Torch Inductor cache directory so that we can delete it later."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "cache_dir = \"temp_cache_dir/\"\n",
    "os.environ[\"TORCHINDUCTOR_CACHE_DIR\"] = cache_dir"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. Load the model\n",
    "\n",
    "We are now ready to load the model we want to compile. In this case, we use a Stable Diffusion pipeline and apply both caching and compilation to showcase that portable compilation works together with other algorithms in `pruna`. Of course, compiling with `torch_compile` alone is also supported."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from diffusers import StableDiffusionPipeline\n",
    "\n",
    "pipe = StableDiffusionPipeline.from_pretrained(\"CompVis/stable-diffusion-v1-4\", torch_dtype=torch.float16)\n",
    "pipe = pipe.to(\"cuda\")\n",
    "\n",
    "prompt = \"a photo of an astronaut riding a horse on mars\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. Smash the model\n",
    "\n",
    "Next, we define the `SmashConfig` and smash the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pruna import SmashConfig, smash\n",
    "\n",
    "smash_config = SmashConfig({\n",
    "    \"deepcache\": {},\n",
    "    \"torch_compile\": {\"make_portable\": True}\n",
    "})\n",
    "pipe = smash(pipe, smash_config=smash_config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. Run and save the compiled model\n",
    "\n",
    "We now run the model twice. The first call is the warm-up inference, which takes approximately 50 seconds in this example. The second call shows the runtime of the compiled model. Afterwards, we save the smashed model to disk."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "for _ in range(2):\n",
    "    start = time.time()\n",
    "    pipe(prompt)\n",
    "    print(f\"Time taken: {time.time() - start} seconds\")\n",
    "\n",
    "pipe.save_pretrained(\"smashed_model/\")"
   ]
  },
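  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The timing loop above can also be wrapped in a small helper. The following is a minimal sketch and not part of `pruna`: `timed_call` is a hypothetical name, and `time.perf_counter` is swapped in for `time.time` because it is a monotonic clock intended for measuring intervals. A stand-in workload replaces `pipe(prompt)` so that the snippet runs on its own.\n",
    "\n",
    "```python\n",
    "import time\n",
    "\n",
    "\n",
    "def timed_call(fn, *args, **kwargs):\n",
    "    # Hypothetical helper: call fn and return (result, elapsed seconds).\n",
    "    start = time.perf_counter()\n",
    "    result = fn(*args, **kwargs)\n",
    "    return result, time.perf_counter() - start\n",
    "\n",
    "\n",
    "# Stand-in workload; in this tutorial it would be `pipe(prompt)`.\n",
    "timings = []\n",
    "for _ in range(2):\n",
    "    _, elapsed = timed_call(sum, range(100_000))\n",
    "    timings.append(elapsed)\n",
    "    print(f\"Time taken: {elapsed:.4f} seconds\")\n",
    "```\n",
    "\n",
    "With the real pipeline, the first entry of `timings` would correspond to the warm-up call and the second to the runtime of the compiled model."
   ]
  },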
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4. Simulate move to a new machine\n",
    "\n",
    "Next, we will delete the compilation cache directory to mimic moving to a new machine. After that, please **restart your kernel or process and continue the tutorial**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import shutil\n",
    "\n",
    "shutil.rmtree(cache_dir)"
   ]
  },
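  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before restarting, it can be useful to confirm that the cache directory is really gone. The following is a minimal sketch using only the standard library: `dir_size_mb` is a hypothetical helper, and a temporary directory stands in for `cache_dir` so that the snippet runs on its own.\n",
    "\n",
    "```python\n",
    "import shutil\n",
    "import tempfile\n",
    "from pathlib import Path\n",
    "\n",
    "\n",
    "def dir_size_mb(path):\n",
    "    # Hypothetical helper: total size of all files under path in MiB (0.0 if missing).\n",
    "    root = Path(path)\n",
    "    if not root.exists():\n",
    "        return 0.0\n",
    "    return sum(f.stat().st_size for f in root.rglob(\"*\") if f.is_file()) / 1024**2\n",
    "\n",
    "\n",
    "# Stand-in for the cache directory, filled with one dummy 1 MiB file.\n",
    "demo_dir = Path(tempfile.mkdtemp())\n",
    "(demo_dir / \"artifact.bin\").write_bytes(bytes(1024**2))\n",
    "print(f\"Before deletion: {dir_size_mb(demo_dir):.1f} MiB\")\n",
    "\n",
    "# ignore_errors=True keeps the deletion from raising if the cell is re-run.\n",
    "shutil.rmtree(demo_dir, ignore_errors=True)\n",
    "print(f\"After deletion: {dir_size_mb(demo_dir):.1f} MiB, exists: {demo_dir.exists()}\")\n",
    "```\n",
    "\n",
    "Applied to this tutorial, the same check on `cache_dir` should report that the directory no longer exists after the deletion above."
   ]
  },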
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5. Load the model\n",
    "\n",
    "We can now load the model and check that the warm-up time has been significantly reduced!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "import torch\n",
    "\n",
    "from pruna import PrunaModel\n",
    "\n",
    "pipe = PrunaModel.from_pretrained(\"smashed_model\", torch_dtype=torch.float16)\n",
    "prompt = \"a photo of an astronaut riding a horse on mars\"\n",
    "\n",
    "for _ in range(2):\n",
    "    start = time.time()\n",
    "    pipe(prompt)\n",
    "    print(f\"Time taken: {time.time() - start} seconds\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Wrap Up\n",
    "\n",
    "Congratulations! You have successfully smashed a model with portable compilation. To fit your use case, you only need to modify steps 1 and 2."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
